{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'3.1.0'"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import math\n",
    "import os\n",
    "import random\n",
    "\n",
    "import torch\n",
    "import triton\n",
    "import triton.backends\n",
    "import triton.language as tl\n",
    "from flash_attn_interface import flash_attn_with_kvcache\n",
    "from flash_mla import get_mla_metadata, flash_mla_with_kvcache\n",
    "\n",
    "from mla import triton_mqa\n",
    "# Ask Triton's autotuner to print the configuration it picks for each tuned kernel.\n",
    "os.environ['TRITON_PRINT_AUTOTUNING'] = '1'\n",
    "# Last expression of the cell: display the installed Triton version string.\n",
    "triton.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the inputs shared by the Triton MQA and FlashMLA cells below.\n",
    "torch.cuda.empty_cache()\n",
    "torch.manual_seed(0)  # fixed seed so every rerun benchmarks the same random inputs\n",
    "device = torch.device('cuda')\n",
    "dtype = torch.bfloat16\n",
    "\n",
    "# Problem size: one new query token per sequence attending to a kv_len-long cache.\n",
    "bs, num_head, q_len, kv_len, rope_head_dim, nope_head_dim = 32, 128, 1, 8192, 64, 128\n",
    "kv_lora_rank = 512\n",
    "num_kv_head = 1  # a single shared key/value head\n",
    "qk_head_dim = kv_lora_rank + rope_head_dim\n",
    "scale = (rope_head_dim + nope_head_dim) ** (-0.5)\n",
    "\n",
    "# Dense tensors are head-major: (batch, heads, seq, dim).\n",
    "q = torch.randn(bs, num_head, q_len, qk_head_dim, device=device, dtype=dtype)\n",
    "k = torch.randn(bs, num_kv_head, kv_len, qk_head_dim, device=device, dtype=dtype)\n",
    "v = torch.randn(bs, num_kv_head, kv_len, kv_lora_rank, device=device, dtype=dtype)\n",
    "# The first kv_lora_rank channels of k carry the same data as v.\n",
    "k[..., :kv_lora_rank] = v\n",
    "\n",
    "cache_seqlens = torch.full((bs,), kv_len, dtype=torch.int32, device=device)\n",
    "total_seqlens = cache_seqlens.sum().item()\n",
    "mean_seqlens = cache_seqlens.float().mean().int().item()\n",
    "max_seqlen = cache_seqlens.max().item()\n",
    "max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256\n",
    "# blocked_k below is a zero-copy view of k, which is only possible when no padding is needed.\n",
    "assert kv_len == max_seqlen_pad, (\n",
    "    f\"kv_len={kv_len} must be a multiple of 256 so k can be viewed as a paged cache\"\n",
    ")\n",
    "\n",
    "# Paged-cache view of k: each sequence owns a contiguous run of block_size-token pages.\n",
    "block_size = 64\n",
    "blocks_per_seq = max_seqlen_pad // block_size\n",
    "block_table = torch.arange(bs * blocks_per_seq, dtype=torch.int32, device=device).view(bs, blocks_per_seq)\n",
    "blocked_k = k.view(block_table.numel(), block_size, num_kv_head, qk_head_dim)\n",
    "\n",
    "# Scheduling metadata for FlashMLA: (query tokens * query heads) per KV head, and the KV head count.\n",
    "tile_scheduler_metadata, num_splits = get_mla_metadata(\n",
    "    cache_seqlens, q_len * num_head // num_kv_head, num_kv_head\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run `triton_mqa` (imported from the local `mla` module) on the dense q/k/v tensors\n",
    "# with no attention mask; the result is kept in `out1`.\n",
    "out1 = triton_mqa(q, k, v, scale, attention_mask=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# flash_mla_with_kvcache takes q as (batch, seq_len_q, num_heads, head_dim), whereas `q`\n",
    "# is laid out (batch, num_heads, seq_len_q, head_dim). Passed unswapped, the heads are\n",
    "# read as query positions and causal=True masks trailing keys differently per head.\n",
    "q_mla = q.transpose(1, 2).contiguous()\n",
    "out2, lse = flash_mla_with_kvcache(\n",
    "            q_mla, blocked_k, block_table, cache_seqlens, kv_lora_rank,\n",
    "            tile_scheduler_metadata, num_splits, softmax_scale=scale, causal=True)\n",
    "# Swap back to the head-major layout of `q`, so out2 keeps the shape it had before the fix.\n",
    "out2 = out2.transpose(1, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.5333253145217896"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Time the `triton_mqa` call with Triton's benchmarking helper; do_bench reports milliseconds.\n",
    "triton.testing.do_bench(lambda: triton_mqa(q, k, v, scale, attention_mask=None))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.15365096926689148"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# flash_mla_with_kvcache takes q as (batch, seq_len_q, num_heads, head_dim), whereas `q`\n",
    "# is laid out (batch, num_heads, seq_len_q, head_dim). Build the swapped tensor once,\n",
    "# outside the timed lambda, so the layout fix is not part of the measurement.\n",
    "q_mla = q.transpose(1, 2).contiguous()\n",
    "# Time the FlashMLA call; do_bench reports milliseconds.\n",
    "triton.testing.do_bench(lambda: flash_mla_with_kvcache(\n",
    "            q_mla, blocked_k, block_table, cache_seqlens, kv_lora_rank,\n",
    "            tile_scheduler_metadata, num_splits, softmax_scale=scale, causal=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _reference_attention(query, key, value, h_q, h_kv, is_causal=False):\n",
    "    \"\"\"Eager float32 attention for a single sequence, used as ground truth.\n",
    "\n",
    "    Args:\n",
    "        query: tensor shaped (h_q, s_q, d).\n",
    "        key: tensor shaped (h_kv, s_k, d).\n",
    "        value: tensor shaped (h_kv, s_k, dv).\n",
    "        h_q: number of query heads.\n",
    "        h_kv: number of key/value heads; each one is shared by h_q // h_kv query heads.\n",
    "        is_causal: if True, query row i only sees keys up to index s_k - s_q + i.\n",
    "\n",
    "    Returns:\n",
    "        (out, lse): out is (h_q, s_q, dv); lse is the (h_q, s_q) log-sum-exp of the scaled scores.\n",
    "    \"\"\"\n",
    "    query, key, value = query.float(), key.float(), value.float()\n",
    "    group = h_q // h_kv\n",
    "    key = key.repeat_interleave(group, dim=0)\n",
    "    value = value.repeat_interleave(group, dim=0)\n",
    "    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))\n",
    "    if is_causal:\n",
    "        s_q, s_k = query.shape[-2], key.shape[-2]\n",
    "        # Bottom-right aligned causal mask: the last query row sees every key.\n",
    "        visible = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q)\n",
    "        scores = scores.masked_fill(~visible, float(\"-inf\"))\n",
    "    lse = scores.logsumexp(dim=-1)\n",
    "    probs = torch.softmax(scores, dim=-1, dtype=torch.float32)\n",
    "    return probs @ value, lse\n",
    "\n",
    "\n",
    "def _cal_diff(x, y, name, tol=1e-5):\n",
    "    \"\"\"Print error metrics between tensors x and y and assert that they nearly match.\n",
    "\n",
    "    The pass/fail criterion is 1 - 2*sum(x*y) / sum(x*x + y*y), which is 0 for identical\n",
    "    tensors; RMSE and max-abs difference are printed for information only.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the similarity-based difference is not below `tol`.\n",
    "    \"\"\"\n",
    "    x, y = x.double(), y.double()\n",
    "    rmse = (x - y).square().mean().sqrt().item()\n",
    "    cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)\n",
    "    amax_diff = (x - y).abs().max().item()\n",
    "    print(f\"{name}: {cos_diff=}, {rmse=}, {amax_diff=}\")\n",
    "    assert cos_diff < tol, f\"{name} mismatch: {cos_diff=}\"\n",
    "\n",
    "\n",
    "@torch.inference_mode()\n",
    "def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen,\n",
    "                   dtype=torch.bfloat16, device=\"cuda\"):\n",
    "    \"\"\"Check flash_mla_with_kvcache against eager attention, then benchmark it.\n",
    "\n",
    "    Args:\n",
    "        b: batch size.\n",
    "        s_q: query tokens per sequence.\n",
    "        mean_sk: KV-cache length (the mean length when `varlen` is True).\n",
    "        h_q: number of query heads.\n",
    "        h_kv: number of key/value heads.\n",
    "        d: query/key head dimension.\n",
    "        dv: value head dimension; values are the first `dv` channels of the keys.\n",
    "        causal: forwarded to the kernel and to the reference implementation.\n",
    "        varlen: if True, draw each sequence length from N(mean_sk, mean_sk / 2), floored at s_q.\n",
    "        dtype: dtype of q and the KV cache.\n",
    "        device: device on which all tensors are created.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if the kernel output or LSE deviates from the reference.\n",
    "    \"\"\"\n",
    "    print(f\"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {varlen=}\")\n",
    "\n",
    "    # Keep the lengths as Python ints so later slicing/statistics need no device sync.\n",
    "    if varlen:\n",
    "        seqlens = [int(max(random.normalvariate(mean_sk, mean_sk / 2), s_q)) for _ in range(b)]\n",
    "    else:\n",
    "        seqlens = [mean_sk] * b\n",
    "    cache_seqlens = torch.tensor(seqlens, dtype=torch.int32, device=device)\n",
    "    total_seqlens = sum(seqlens)\n",
    "    max_seqlen = max(seqlens)\n",
    "    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256\n",
    "\n",
    "    q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)\n",
    "    block_size = 64\n",
    "    blocks_per_seq = max_seqlen_pad // block_size\n",
    "    block_table = torch.arange(b * blocks_per_seq, dtype=torch.int32, device=device).view(b, blocks_per_seq)\n",
    "    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device)\n",
    "    # Poison the padding with NaN so any read past a sequence's length shows up in the output.\n",
    "    padded_k = blocked_k.view(b, max_seqlen_pad, h_kv, d)\n",
    "    for i, seqlen in enumerate(seqlens):\n",
    "        padded_k[i, seqlen:] = float(\"nan\")\n",
    "\n",
    "    tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)\n",
    "\n",
    "    def flash_mla():\n",
    "        return flash_mla_with_kvcache(\n",
    "            q, blocked_k, block_table, cache_seqlens, dv,\n",
    "            tile_scheduler_metadata, num_splits, causal=causal,\n",
    "        )\n",
    "\n",
    "    def ref_mla():\n",
    "        out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=device)\n",
    "        lse = torch.empty(b, h_q, s_q, dtype=torch.float32, device=device)\n",
    "        flat_k = blocked_k.view(-1, h_kv, d)\n",
    "        for i, seqlen in enumerate(seqlens):\n",
    "            begin = i * max_seqlen_pad\n",
    "            keys = flat_k[begin:begin + seqlen]\n",
    "            O, LSE = _reference_attention(\n",
    "                q[i].transpose(0, 1),\n",
    "                keys.transpose(0, 1),\n",
    "                keys[..., :dv].transpose(0, 1),\n",
    "                h_q=h_q,\n",
    "                h_kv=h_kv,\n",
    "                is_causal=causal,\n",
    "            )\n",
    "            out[i] = O.transpose(0, 1)\n",
    "            lse[i] = LSE\n",
    "        return out, lse\n",
    "\n",
    "    out_flash, lse_flash = flash_mla()\n",
    "    out_torch, lse_torch = ref_mla()\n",
    "    _cal_diff(out_flash, out_torch, \"out\")\n",
    "    _cal_diff(lse_flash, lse_torch, \"lse\")\n",
    "\n",
    "    t = triton.testing.do_bench(flash_mla)  # milliseconds\n",
    "    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2\n",
    "    mem_bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)\n",
    "    # With t in ms, GFLOP/ms equals TFLOP/s and MB/ms equals GB/s.\n",
    "    print(f\"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS, {mem_bytes / 10 ** 6 / t:.0f} GB/s\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
